"""Runner for on-policy MA algorithms."""
import numpy as np
import torch
from harl.runners.on_policy_base_runner import OnPolicyBaseRunner


class OnPolicyMARunner(OnPolicyBaseRunner):
    """Runner for on-policy MA algorithms."""

    def train(self):
        """Training procedure for MAPPO.

        Computes advantages from the critic buffer (denormalizing value
        predictions when a value normalizer is active), normalizes them for
        the "FP" state type, trains every actor except the one frozen via
        ``algo_args["train"]["fix_agent_id"]``, then trains the critic.

        Returns:
            tuple: (actor_train_infos, critic_train_info) where
                actor_train_infos is a list with one train-info entry per
                agent and critic_train_info is the critic's train info.
        """
        actor_train_infos = [None] * self.num_agents

        # compute advantages; value predictions must be denormalized before
        # subtracting when a value normalizer is in use
        if self.value_normalizer is not None:
            advantages = self.critic_buffer.returns[
                :-1
            ] - self.value_normalizer.denormalize(self.critic_buffer.value_preds[:-1])
        else:
            advantages = (
                self.critic_buffer.returns[:-1] - self.critic_buffer.value_preds[:-1]
            )

        # normalize advantages for FP, ignoring steps where agents are
        # inactive (active_masks == 0) via NaN-aware statistics
        if self.state_type == "FP":
            active_masks_collector = [
                self.actor_buffer[i].active_masks for i in range(self.num_agents)
            ]
            active_masks_array = np.stack(active_masks_collector, axis=2)
            advantages_copy = advantages.copy()
            advantages_copy[active_masks_array[:-1] == 0.0] = np.nan
            mean_advantages = np.nanmean(advantages_copy)
            std_advantages = np.nanstd(advantages_copy)
            advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        # update actors
        if self.share_param:
            # a single shared actor is trained once; every agent reports the
            # same train info (the old torch.randperm loop assigned the same
            # object to every slot, so a plain fill is equivalent)
            actor_train_info = self.actor[0].share_param_train(
                self.actor_buffer, advantages.copy(), self.num_agents, self.state_type
            )
            actor_train_infos = [actor_train_info] * self.num_agents
        else:
            fix_agent_id = self.algo_args["train"]["fix_agent_id"]
            for agent_id in range(self.num_agents):
                # the fixed agent keeps its current policy untouched
                if agent_id == fix_agent_id:
                    continue

                if self.state_type == "EP":
                    actor_train_info = self.actor[agent_id].train(
                        self.actor_buffer[agent_id], advantages.copy(), "EP"
                    )
                elif self.state_type == "FP":
                    # each agent trains on its own advantage slice
                    actor_train_info = self.actor[agent_id].train(
                        self.actor_buffer[agent_id],
                        advantages[:, :, agent_id].copy(),
                        "FP",
                    )
                actor_train_infos[agent_id] = actor_train_info

        # The skipped (fixed) agent has no train info of its own; reuse any
        # trained agent's info so downstream logging sees a uniform list.
        # This generalizes the previous hard-coded 0/1 swap, which left a
        # None behind when fix_agent_id >= 2 and raised IndexError when
        # num_agents == 1.
        fallback = next(
            (info for info in actor_train_infos if info is not None), None
        )
        actor_train_infos = [
            info if info is not None else fallback for info in actor_train_infos
        ]

        # update critic
        critic_train_info = self.critic.train(self.critic_buffer, self.value_normalizer)

        return actor_train_infos, critic_train_info
